#include "SIFT.hpp"
#include <cmath>
#include <Math/Matrix3x3.hpp>
#include <Image/ImageFilter.hpp>
#include <common.hpp>
using namespace std;

namespace zzz {
void SIFT::RefineKeypoint() {
  Matrix3x3f A; //Hessian
  Vector3f B,X;
  for (size_t o=0; o<keypoint_.size(); o++) {
    int octave_scale=Pow(2,o);
    int nrow=nrow_/octave_scale, ncol=ncol_/octave_scale, nscale=dog_[o].size();
    for (size_t i=0; i<keypoint_[o].size(); i++) {
      //local search
      bool good=true;
      int counter=0;
      while(true) {
        Vector3f &pos=keypoint_[o][i];
#define AT(dr,dc,ds) dog_[o][int(pos[2])+(ds)]->At(int(pos[0])+(dr),int(pos[1])+(dc))
        float Dr=0.5*(AT(1,0,0)-AT(-1,0,0));
        float Dc=0.5*(AT(0,1,0)-AT(0,-1,0));
        float Ds=0.5*(AT(0,0,1)-AT(0,0,-1));
        float Drr = (AT(+1,0,0) + AT(-1,0,0) - 2.0 * AT(0,0,0)) ;
        float Dcc = (AT(0,+1,0) + AT(0,-1,0) - 2.0 * AT(0,0,0)) ;
        float Dss = (AT(0,0,+1) + AT(0,0,-1) - 2.0 * AT(0,0,0)) ;
        float Drc = 0.25 * (AT(+1,+1,0) + AT(-1,-1,0) - AT(-1,+1,0) - AT(+1,-1,0)) ;
        float Drs = 0.25 * (AT(+1,0,+1) + AT(-1,0,-1) - AT(-1,0,+1) - AT(+1,0,-1)) ;
        float Dcs = 0.25 * (AT(0,+1,+1) + AT(0,-1,-1) - AT(0,-1,+1) - AT(0,+1,-1)) ;
#undef AT
        A(0,0)=Drr;
        A(1,1)=Dcc;
        A(2,2)=Dss;
        A(0,1)=Drc; A(1,0)=Drc;
        A(0,2)=Drs; A(2,0)=Drs;
        A(1,2)=Dcs; A(2,1)=Dcs;
        B[0]=-Dr;
        B[1]=-Dc;
        B[2]=-Ds;
        if (!A.Invert()) {
          good=false;
          break;
        }
        X=A*B;
        ZLOG(ZVERBOSE)<<pos<<X<<endl;
        Vector3f dX(0,0,0);
        dX[0]= ((X[0] >  0.6 && pos[0] < nrow-2) ?  1 : 0)
          + ((X[0] < -0.6 && pos[0] > 1) ? -1 : 0) ;
        dX[1]= ((X[1] >  0.6 && pos[1] < ncol-2) ?  1 : 0)
          + ((X[1] < -0.6 && pos[1] > 1) ? -1 : 0) ;
        dX[2]= ((X[2] >  0.6 && pos[2] < nscale-2) ?  1 : 0)
          + ((X[2] < -0.6 && pos[2] > 1) ? -1 : 0) ;
        if (counter++>5 || (dX[0]==0 && dX[1]==0 && dX[2]==0)) {
          //done searching
          //check if value is big enough
          float value=dog_[o][int(pos[2])]->At(int(pos[0]),int(pos[1]));
          value+=-0.5*B.Dot(X);
          if (value<THRESH) {
            good=false;
            break;
          }
          //eliminate edge response
          float score=(Drr+Dcc)*(Drr+Dcc)/(Drr*Dcc-Drc*Drc);
          if (score<0 && score>(EDGETHRESH+1)*(EDGETHRESH+1)/EDGETHRESH) {
            good=false;
            break;
          }
          //out of range
          pos+=X;
          if (pos[0]<0 || pos[0]>nrow-1 || pos[1]<0 || pos[1]>ncol-1 || pos[2]<0 || pos[2]>nscale-1) {
            good=false;
            break;
          }
          break;
        } else {
          pos+=dX;
          continue; //set new init point and search again
        }
        if (good=false) {
          keypoint_[o].erase(keypoint_[o].begin()+i);
          i--;
        }
      }
    }
  }
}

void SIFT::FindLocalMaxima() {
  keypoint_.clear();
  for (size_t o=0; o<dog_.size(); o++) {
    int octave_scale=Pow(2,o);
    int nrow=nrow_/octave_scale, ncol=ncol_/octave_scale, nscale=dog_[o].size();
    vector<Vector3f> curoctave;
    const int neighbor[][3]={
      //r,c,s
      {-1,0,0},{-1,-1,0},{-1,1,0},{0,-1,0},{0,1,0},{1,0,0},{1,-1,0},{1,1,0},
      {-1,0,-1},{-1,-1,-1},{-1,1,-1},{0,-1,-1},{0,1,-1},{1,0,-1},{1,-1,-1},{1,1,-1},{0,0,-1},
      {-1,0,1},{-1,-1,1},{-1,1,1},{0,-1,1},{0,1,1},{1,0,1},{1,-1,1},{1,1,1},{0,0,1}};
      //find maxima
    for (int s=1; s<nscale-2; s++) for (int r=IMAGEEDGEDIST+1; r<nrow-IMAGEEDGEDIST-2; r++) for (int c=IMAGEEDGEDIST+1; c<ncol-IMAGEEDGEDIST-2; c++) {
      float cur=dog_[o][s]->At(r,c);
      bool good=true;
      for (int i=0; i<26; i++) {
        if (cur - dog_[o][s+neighbor[i][2]]->At(r+neighbor[i][0],c+neighbor[i][1]) < THRESH) {
          good=false;
          break;
        }
      }
      if (good) {
        curoctave.push_back(Vector3f(r,c,s));
#ifdef ZZZ_LIB_BOOST
        ReportMaxima(dog_[o][s],r,c);
#endif // ZZZ_LIB_BOOST
      }
    }
    //find minima
    for (int s=1; s<nscale-2; s++) for (int r=IMAGEEDGEDIST+1; r<nrow-IMAGEEDGEDIST-2; r++) for (int c=IMAGEEDGEDIST+1; c<ncol-IMAGEEDGEDIST-2; c++) {
      float cur=dog_[o][s]->At(r,c);
      bool good=true;
      for (int i=0; i<26; i++) {
        if (dog_[o][s+neighbor[i][2]]->At(r+neighbor[i][1],c+neighbor[i][0]) - cur < THRESH) {
          good=false;
          break;
        }
      }
      if (good) {
        curoctave.push_back(Vector3f(r,c,s));
#ifdef ZZZ_LIB_BOOST
        ReportMaxima(dog_[o][s],r,c);
#endif // ZZZ_LIB_BOOST
      }
    }
    keypoint_.push_back(curoctave);
  }
}

// Convenience entry point: remembers `image` as the source image (stored by
// pointer in ori_), builds the Gaussian pyramid from it and then derives the
// Difference-of-Gaussian pyramid from that.
void SIFT::BuildDoGOctaves(Image<float> * image) {
  this->ori_ = image;
  this->BuildGaussianOctaves();
  this->BuildDoGOctaves();
}

void SIFT::BuildDoGOctaves() {
  ClearOctave(dog_);
  for (size_t i=0; i<gauss_.size(); i++) {
    vector<Image<float> *> cur;
    for (size_t j=1; j<gauss_[i].size(); j++) {
      Image<float> *img=new Image<float>(*(gauss_[i][j]));
      *img-=*(gauss_[i][j-1]);
      cur.push_back(img);
#ifdef ZZZ_LIB_BOOST
      ReportBuildOctave(img);
#endif // ZZZ_LIB_BOOST
    }
    dog_.push_back(cur);
  }
}

// Builds the base image of the Gaussian pyramid from ori_.
//
// ori_ is resized to twice its size |OMIN| times, then blurred so that its
// total blur reaches SIGMA. The source is assumed to already carry a blur of
// INITSIGMA at native resolution, i.e. INITSIGMA*2^scale after resizing.
// Gaussian blurs compose in quadrature, so only
// sqrt(SIGMA^2 - (INITSIGMA*2^scale)^2) must be added. That is why the formula
// subtracts. When the image is already at least that blurry, no extra blur is
// applied; sqrt of a negative value would otherwise produce a NaN sigma.
//
// Side effect: nrow_/ncol_ are set to the resized dimensions, which are
// treated as octave 0 from here on.
// Returns a newly allocated image; the caller takes ownership.
//
// NOTE(review): a positive OMIN is also treated as upsampling; confirm whether
// a positive OMIN should downsample instead.
Image<float> * SIFT::ScaleInitImage() {
  int scale=(OMIN<0) ? -OMIN : OMIN;
  Image<float> scale_img(*ori_);
  for (int i=0; i<scale; i++)
    scale_img.Resize(scale_img.Rows()*2,scale_img.Cols()*2);
  // variance still to be added to reach SIGMA
  double sigma_sq = SIGMA * SIGMA - INITSIGMA * INITSIGMA * Pow(2,scale*2);
  Image<float> *ret;
  if (sigma_sq > 0) {
    ret=new Image<float>;
    ImageFilter<float>::GaussBlurImage(&scale_img, ret,4, sqrt(sigma_sq));
  } else {
    ret=new Image<float>(scale_img); // already blurred enough
  }

  //treat the scaled image as the first one, easier for later process
  ncol_=ori_->Cols()*Pow(2,scale);
  nrow_=ori_->Rows()*Pow(2,scale);
  return ret;
}

void SIFT::BuildGaussianOctaves() {
  ClearOctave(gauss_);
  // start with initial source image
  Image<float> * timage = ScaleInitImage();
  for (int i = 0; i < NOCTAVE; i++) {
    zout<<"BuildGaussianOctave:"<<i<<endl;
    vector<Image<float> *> scales = BuildGaussianScales(timage);
    gauss_.push_back(scales);
    // halve the image size for next iteration
    Image<float> *simage=new Image<float>;
    scales[NSCALE]->HalfImageTo(*simage);
    timage = simage;
    //ATTENTION:
    //timage is push_backed into scales when call BuildGaussianScales()
    //so don't delete timage except for the last useless one
  }
  delete timage;
}

// Builds one octave of Gaussian scale space: NSCALE+3 images whose nominal
// blur grows by the constant factor k = 2^(1/NSCALE) per level.
//
// `image` itself becomes level 0 (it is stored, not copied). Each later level
// is produced by blurring the previous level just enough to go from
// SIGMA*k^(i-1) to SIGMA*k^i, i.e. by sqrt(sigma_i^2 - sigma_{i-1}^2), since
// Gaussian blurs compose in quadrature. Returned images are newly allocated
// (except level 0) and owned by the caller.
vector<Image<float> *> SIFT::BuildGaussianScales(Image<float> * image) {
  const double step = Pow(2, 1.0/(float)NSCALE);
  vector<Image<float> *> levels(1, image);
  int level = 1;
  while (level < NSCALE + 3) {
    // incremental blur on top of the previous level
    const float prev_sigma = Pow(step, level - 1) * SIGMA;
    const float next_sigma = Pow(step, level) * SIGMA;
    const float extra = sqrt(next_sigma*next_sigma - prev_sigma*prev_sigma);

    Image<float> *blurred = new Image<float>;
    ImageFilter<float>::GaussBlurImage(levels.back(), blurred,4, extra);
    levels.push_back(blurred);
#ifdef ZZZ_LIB_BOOST
    ReportBuildGaussian(blurred);
#endif // ZZZ_LIB_BOOST
    ++level;
  }
  return levels;
}

// Deletes every image pointer held in `octave` (all octaves, all levels) and
// leaves the container empty.
void SIFT::ClearOctave(Octave &octave) {
  const size_t n_octaves = octave.size();
  for (size_t a = 0; a < n_octaves; ++a) {
    const size_t n_levels = octave[a].size();
    for (size_t b = 0; b < n_levels; ++b) {
      delete octave[a][b];
    }
  }
  octave.clear();
}
}  // namespace zzz
